"""
The code is released exclusively for review purposes with the following terms:
PROPRIETARY AND CONFIDENTIAL. UNAUTHORIZED USE, COPYING, OR DISTRIBUTION OF THE 
CODE, VIA ANY MEDIUM, IS STRICTLY PROHIBITED. BY ACCESSING THE CODE, THE 
REVIEWERS AGREE TO DELETE THEM FROM ALL MEDIA AFTER THE REVIEW PERIOD IS OVER.
"""
""" Create LIME explanations -
    - For various environments
    - For "combined" environments (will use base perturbations when bootstrap samples are 
        used to create environments)
    - median of various environments """


import numpy as np
import sys
sys.path.append("../utilities/")
import os
from time import time

# from joblib import Parallel, delayed
# from sklearn.utils import check_random_state
import yaml
import pickle

# fname_lime_exp
from utils import (fname_env_perts, fname_exp,
                    fname_base_perts, fname_preds, compute_weights,
                    lime_explanation, create_dir_if_not_exist)

# Parse the command-line arguments that select the config file, dataset,
# model and perturbation scheme for this run.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--config_fname")
parser.add_argument("--dataset_key")
parser.add_argument("--model_key")
parser.add_argument("--pert_key")
# --maple_leaves is only consumed when --pert_key is "MAPLE" (it is passed as
# min_samples_leaf to the random forest).  Letting argparse do the int
# conversion means omitting the flag no longer crashes with int(None) for the
# other perturbation schemes.
parser.add_argument("--maple_leaves", type=int, default=None)
args = parser.parse_args()
if args.pert_key == "MAPLE" and args.maple_leaves is None:
    # Fail early with a usage message instead of deep inside the forest fit.
    parser.error("--maple_leaves is required when --pert_key is MAPLE")

# Load the YAML run configuration from config/<config_fname>.
# The file is opened in a context manager so the handle is closed as soon as
# parsing finishes instead of being left to the garbage collector.
# NOTE(review): yaml.FullLoader is not intended for untrusted input; only
# point --config_fname at config files you control.
with open(os.path.join("config", args.config_fname)) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Load the pickled inputs: environment perturbations, base perturbations and
# the model predictions on the perturbations.  Every file is opened in a
# context manager so no handles are leaked.
# NOTE(review): pickle.load can execute arbitrary code; only load files that
# were produced by this project's own pipeline.
dirname = os.path.join("data", args.dataset_key, "perturbations")
env_pert_fname = fname_env_perts(config, "Env_Perturbations",
                    args.pert_key, args.dataset_key)+".pkl"
with open(os.path.join(dirname, env_pert_fname), "rb") as pkl_file:
    env_perturbations = pickle.load(pkl_file)

base_pert_fname = fname_base_perts(config,
                    args.pert_key, args.dataset_key)+".pkl"
with open(os.path.join(dirname, base_pert_fname), "rb") as pkl_file:
    base_perturbations = pickle.load(pkl_file)

# Predictions live in a sibling directory of the perturbations.
dirname = os.path.join("data", args.dataset_key, "predictions")
preds_fname = fname_preds(config, args.pert_key, args.model_key, args.dataset_key)+".pkl"
with open(os.path.join(dirname, preds_fname), "rb") as pkl_file:
    y_pred_perts = pickle.load(pkl_file)

# Gather the quantities shared by the base and per-environment explanation
# loops below.

# -- taken from the loaded perturbation pickles
samp_perts_exp = base_perturbations["samp_perts_exp"]
samp_inds_env = env_perturbations["samp_inds_env"]
n_data_all = len(base_perturbations["indices"])   # number of explained instances
p = samp_perts_exp[0].shape[1]                     # number of columns per perturbation matrix

# -- taken from the YAML config
_env_cfg = config["Env_Perturbations"]
n_perts_env = _env_cfg["cnt"]
num_envs = _env_cfg["num_envs"]
kernel_width = _env_cfg["kernel_width"]
normalize_weights = _env_cfg["normalize_weights"]
num_nonzeros = config["LIME"]["non_zeros"]

## For MAPLE VVV
# Helpers that derive local sample weights from a random forest instead of a
# distance kernel.  Only defined when the MAPLE perturbation key is selected.
if args.pert_key == "MAPLE":

    from sklearn.ensemble import RandomForestRegressor
    from sklearn.preprocessing import OneHotEncoder

    # Training data for the forest: the first row of every instance's
    # perturbation matrix, paired with the first prediction of that instance.
    # NOTE(review): presumably row 0 is the unperturbed instance itself --
    # confirm against the perturbation generator.
    X_test = np.vstack([x[0] for x in base_perturbations["samp_perts_exp"]])
    y_test = np.array([y[0] for y in y_pred_perts])

    def train_maple_weights(X_test, y_test):
        """Fit and return the random forest used to derive MAPLE weights.

        ``min_samples_leaf`` comes from ``--maple_leaves``.  The forest is
        seeded with the current wall-clock second, so results are NOT
        reproducible from one run to the next.
        """
        rfr = RandomForestRegressor(min_samples_leaf=args.maple_leaves,
                                    random_state=int(time()))
        rfr.fit(X_test, y_test)
        return rfr

    # RF Regressor for MAPLE weights
    rfr = train_maple_weights(X_test, y_test)

    def compute_maple_weights(x):
        """Return a pairwise leaf co-occurrence similarity matrix for ``x``.

        Each row of ``x`` is sent through every tree of ``rfr`` and the leaf
        ids are one-hot encoded per tree.  Entry ``(i, j)`` of the result is
        the number of trees in which rows ``i`` and ``j`` land in the same
        leaf, plus a small constant (1e-5) so no weight is exactly zero, and
        the whole matrix is scaled by its maximum.

        Returns a dense ``(len(x), len(x))`` ndarray; callers take row 0 as
        the weights relative to the first sample.
        """
        leaves = rfr.apply(x)
        leaf_enc = OneHotEncoder().fit(leaves)
        M = leaf_enc.transform(leaves)
        # Explicit matrix product + toarray(): same result as the former
        # ``(M*M.T).todense().A`` for sparse matrices, and still correct if
        # the encoder returns a sparse array (where ``*`` is element-wise).
        S = (M @ M.transpose()).toarray() + 1e-5

        return S/np.max(S)

# Base explanations: one LIME fit per instance, using only the leading
# n_perts_env perturbations (and their predictions) of that instance.
st_time = time()
lime_base = np.zeros((n_data_all, p))
paired_inputs = zip(y_pred_perts, base_perturbations["samp_perts_exp"])
for row, (preds_all, perts_all) in enumerate(paired_inputs):
    perts_head = perts_all[:n_perts_env]
    preds_head = preds_all[:n_perts_env]

    # Keep only the columns that are non-zero somewhere in the selected rows;
    # the dropped columns keep their zero coefficient in lime_base.
    keep_cols = np.any(perts_head != 0, axis=0)
    perts_kept = perts_head[:, keep_cols]

    if args.pert_key != "MAPLE":
        # Kernel weights computed on the reduced perturbation matrix.
        weights = compute_weights(
                perts_kept,
                distance_metric="euclidean",
                kernel_width=kernel_width,
                normalize=normalize_weights)
    else:
        # Forest-based weights are computed on the full-width rows (zero
        # columns included); row 0 of the similarity matrix is used.
        weights = compute_maple_weights(perts_head)[0]
        if normalize_weights:
            weights = weights/weights.sum()

    # Write the fitted coefficients back into the kept columns only.
    lime_base[row, keep_cols] = lime_explanation(perts_kept,
                             preds_head,
                             weights.ravel(),
                             num_nonzeros=num_nonzeros,
                             debias=True)
lime_base_time = time()-st_time

# Per-environment explanations: for every instance, fit one LIME model on
# each environment's subset of that instance's perturbations.
st_time = time()
lime_envs = np.zeros((n_data_all, p, num_envs))
for idx in range(n_data_all):

    y_pred = y_pred_perts[idx]
    samp_pert_exp0 = base_perturbations["samp_perts_exp"][idx]

    # Remove the columns that are zero in every perturbation of this
    # instance; their coefficients stay at zero in lime_envs.
    nzinds1 = ~np.all(samp_pert_exp0 == 0, axis=0)
    samp_pert_exp = samp_pert_exp0[:, nzinds1]

    for env in range(num_envs):
        # Row indices of the perturbations that make up this environment
        # (looked up once instead of once per use).
        env_inds = samp_inds_env[idx, :, env]
        samp_pert_exp_env = samp_pert_exp[env_inds, :]
        y_pred_env = y_pred[env_inds]

        if args.pert_key == "MAPLE":
            # Forest-based weights use the full-width rows (zero columns
            # included); row 0 of the similarity matrix is used.
            local_weights_env = compute_maple_weights(
                                    samp_pert_exp0[env_inds, :])[0]
            # Honour the normalize_weights setting, consistent with the base
            # explanations and with the compute_weights branch below
            # (previously this branch normalized unconditionally).
            if normalize_weights:
                local_weights_env = local_weights_env/local_weights_env.sum()
        else:
            local_weights_env = compute_weights(
                        samp_pert_exp_env,
                        distance_metric="euclidean",
                        kernel_width=kernel_width,
                        normalize=normalize_weights)

        # Environment-wise explanations
        lime_envs[idx, nzinds1, env] = lime_explanation(samp_pert_exp_env,
                                y_pred_env,
                                local_weights_env.ravel(),
                                num_nonzeros=num_nonzeros,
                                debias=True)

lime_envs_time = time()-st_time

# Dump the explanations (and the wall-clock time spent on each loop) to
# data/<dataset_key>/explanations/<exp_fname>.
exp_fname = fname_exp(config, "LIME", "Env_Perturbations",
                    args.pert_key, args.model_key,
                    args.dataset_key)+".pkl"

lime_explanations = {"lime_base": lime_base,
                    "lime_envs": lime_envs,
                    "lime_base_time": lime_base_time,
                    "lime_envs_time": lime_envs_time}
dirname = os.path.join("data", args.dataset_key, "explanations")
create_dir_if_not_exist(dirname)

# Context manager guarantees the pickle is flushed and the handle closed
# before the script reports the output path.
with open(os.path.join(dirname, exp_fname), "wb") as out_file:
    pickle.dump(lime_explanations, out_file)

# Report the shapes of the results and where they were written.
print(lime_base.shape)
print(lime_envs.shape)
print(os.path.join(dirname, exp_fname))
